import gym
import gym.error
import numpy as np
from gym.spaces import Box
from ray.rllib.env.multi_agent_env import MultiAgentEnv
import grl.envs.mujoco_rarl
import scipy.stats
from .torch_utils import RunningStat, ZFilter, Identity, StateWithTime, RewardFilter
import random


class TimeLimit(gym.Wrapper):
    """Enforce an episode step limit and append normalized elapsed time to observations.

    After ``max_episode_steps`` calls to ``step`` the episode is forced to end
    (``done=True``); ``info['TimeLimit.truncated']`` is set to ``True`` when the
    limit, rather than the wrapped env, ended the episode. Every observation
    returned by ``reset`` and ``step`` gets one extra trailing feature,
    ``elapsed_steps / max_episode_steps`` (0.0 right after reset, 1.0 at the
    limit), and ``observation_space`` is widened with bounds [0.0, 1.0] for it.

    Args:
        env: The env to wrap. Its observation space must expose ``low``/``high``
            arrays; they are unpacked element-wise, so 1-D bounds are assumed.
        max_episode_steps: Step limit. Falls back to ``env.spec.max_episode_steps``
            when omitted and the env has a spec.

    Raises:
        ValueError: If no positive step limit is given or found on the spec.
            The time feature divides by this limit, so it must be positive.
    """

    def __init__(self, env, max_episode_steps=None):
        super().__init__(env)
        if max_episode_steps is None and self.env.spec is not None:
            max_episode_steps = self.env.spec.max_episode_steps
        # Fail fast here instead of with an opaque TypeError/ZeroDivisionError
        # on the first reset() when the time feature is computed.
        if max_episode_steps is None or max_episode_steps <= 0:
            raise ValueError(
                "TimeLimit requires a positive max_episode_steps "
                "(none was given and the env spec does not provide one); "
                "got {!r}".format(max_episode_steps))
        if self.env.spec is not None:
            # Keep the spec consistent with the limit actually enforced here.
            self.env.spec.max_episode_steps = max_episode_steps
        self._max_episode_steps = max_episode_steps
        # None until reset() is called; step() asserts on it.
        self._elapsed_steps = None

        base_space = self.env.observation_space
        # One extra dimension for the elapsed-time fraction, bounded to [0, 1].
        self.observation_space = Box(low=np.asarray([*base_space.low, 0.0]),
                                     high=np.asarray([*base_space.high, 1.0]),
                                     dtype=np.float32)

    def _add_timestep_to_obs(self, obs):
        """Return ``obs`` with the elapsed-step fraction appended as a last element."""
        return np.concatenate((obs, [self._elapsed_steps / self._max_episode_steps]))

    def step(self, action):
        """Step the wrapped env, ending the episode once the step limit is reached."""
        assert self._elapsed_steps is not None, "Cannot call env.step() before calling reset()"
        observation, reward, done, info = self.env.step(action)
        self._elapsed_steps += 1
        if self._elapsed_steps >= self._max_episode_steps:
            # Truncated only if the env itself had not already finished.
            info['TimeLimit.truncated'] = not done
            done = True

        return self._add_timestep_to_obs(observation), reward, done, info

    def reset(self, **kwargs):
        """Reset the step counter and the wrapped env; the time feature starts at 0.0."""
        self._elapsed_steps = 0
        return self._add_timestep_to_obs(self.env.reset(**kwargs))


class MujocoAdvEnv(MultiAgentEnv):
    """Two-agent zero-sum RLlib wrapper around a time-limited gym env.

    Agent ``standard_agent_id`` (0) drives the underlying env. Agent
    ``adversary_agent_id`` (1) receives the same observation and the negated
    reward. ``step`` currently applies only the standard agent's action; the
    adversary's action is ignored because no attack method is wired up yet.

    ``env_config`` keys:
        base_gym_name (optional, default ``"Hopper-v2"``): gym id of the base env.
        max_episode_length (optional, default 2048): step limit for ``TimeLimit``.
        norm_states, clip_obs (required): enable/clip a ``ZFilter`` on observations.
        norm_rewards, clip_rew, gamma (required): ``norm_rewards == "returns"``
            enables a ``RewardFilter`` on rewards.
    """

    def __init__(self, env_config=None):
        # A missing config is treated as empty, so the failure is a clear KeyError
        # on the first required key rather than an AttributeError on None.
        if env_config is None:
            env_config = {}

        self.base_gym_name = env_config.get("base_gym_name", "Hopper-v2")
        gym_spec = gym.spec(self.base_gym_name)

        max_episode_length = int(env_config.get("max_episode_length", 2048))
        self._env = TimeLimit(gym_spec.make(), max_episode_steps=max_episode_length)

        self._default_protag_reward_this_episode = None

        self.standard_agent_id = 0
        self.adversary_agent_id = 1

        self._acting_agent_idx = 0

        # Good discussion of how action_space and observation_space work in MultiAgentEnv:
        # https://github.com/ray-project/ray/issues/6875
        # These dicts are not part of the MultiAgentEnv parent class and RLlib does
        # not read them from a multi-agent env; they exist so the correct per-policy
        # obs/action spaces can be passed into the policy configs.
        # NOTE(review): the key names differ between the two dicts
        # ("standard_agent" vs "standard_agent_partial"). Kept as-is because
        # external configs may look them up by these exact names.
        # TODO: add a config for choosing the adversary's attack method; the
        # adversary action space currently mirrors the base env's action space.
        self.action_space = {
            "adversary": Box(low=-1.0, high=1.0, shape=self._env.action_space.shape, dtype=np.float32),
            "standard_agent": Box(low=-1.0, high=1.0, shape=self._env.action_space.shape, dtype=np.float32),
        }

        self.observation_space = {
            "adversary": self._env.observation_space,
            "standard_agent_partial": self._env.observation_space,
        }

        # Observation / reward normalization settings.
        self.norm_states = env_config["norm_states"]
        self.clip_obs = env_config["clip_obs"]
        self.norm_rewards = env_config["norm_rewards"]
        self.clip_rew = env_config["clip_rew"]
        self.gamma = env_config["gamma"]

        # Both filters start as pass-through so reset()/step() can call them
        # unconditionally; the optional normalizing filters wrap these.
        self.state_filter = Identity()
        self.reward_filter = Identity()

        if self.norm_states:
            self.state_filter = ZFilter(self.state_filter,
                                        shape=[self._env.observation_space.shape[0]],
                                        clip=self.clip_obs)

        if self.norm_rewards == "returns":
            self.reward_filter = RewardFilter(self.reward_filter, shape=(),
                                              gamma=self.gamma, clip=self.clip_rew)

        # Running total of the raw (unfiltered) env reward; reset to 0.0 in reset().
        self.total_true_reward = 0.0

    def reset(self):
        """Reset the env and return the initial observation for both agents.

        Returns:
            obs (dict): ``{standard_agent_id: obs, adversary_agent_id: obs}``, where
                both entries are the same state-filtered observation.
        """
        # Seed the base env from Python's global ``random`` module. Episodes are
        # reproducible only if the caller has seeded ``random`` itself.
        self._env.seed(random.getrandbits(31))
        start_state = self._env.reset()
        self.total_true_reward = 0.0
        self.state_filter.reset()
        self.reward_filter.reset()

        obs = self.state_filter(start_state, reset=True)
        assert obs is not None
        return {self.standard_agent_id: obs, self.adversary_agent_id: obs}

    def step(self, action_dict):
        """Apply the standard agent's action and return per-agent results.

        All returned dicts are keyed by the integer agent ids
        (``standard_agent_id`` / ``adversary_agent_id``).

        Args:
            action_dict (dict): Must contain ``standard_agent_id``; any adversary
                action is currently ignored.

        Returns:
            obs (dict): The same state-filtered observation for each agent.
            rewards (dict): Reward-filtered env reward for the standard agent and
                its negation for the adversary (zero-sum).
            dones (dict): Per-agent done flags plus the special ``"__all__"`` key,
                which mirrors the env's done flag so RLlib ends the episode.
            infos (dict): Empty per agent, except that on the final step the
                standard agent's info carries ``full_episode_completed`` and the
                raw episode return under ``total_rewards``.
        """
        observation, rew, done, _ = self._env.step(action_dict[self.standard_agent_id])
        observation = self.state_filter(observation)

        # Log the raw return; train on the (optionally normalized) reward. The
        # reward filter is called exactly once per step, since it may be stateful.
        self.total_true_reward += rew
        filtered_rew = self.reward_filter(rew)

        obs = {self.standard_agent_id: observation}
        rews = {self.standard_agent_id: filtered_rew}
        # RLlib only terminates a multi-agent episode when "__all__" is True.
        dones = {"__all__": done, self.standard_agent_id: done}
        infos = {self.standard_agent_id: {}}

        if self.adversary_agent_id is not None:
            obs[self.adversary_agent_id] = observation
            rews[self.adversary_agent_id] = -filtered_rew
            dones[self.adversary_agent_id] = done
            infos[self.adversary_agent_id] = {}

        if done:
            infos[self.standard_agent_id] = {"full_episode_completed": True,
                                             "total_rewards": self.total_true_reward}

        return obs, rews, dones, infos

    def render(self, mode="human"):
        """Delegate rendering to the wrapped env."""
        return self._env.render(mode=mode)


